{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training neural network with DALI and Paxml\n",
    "\n",
    "This simple example shows how to train a neural network implemented in Paxml with DALI data preprocessing. It builds on MNIST training example from Paxml codebse that can be found [here](https://github.com/google/paxml/blob/paxml-v1.1.0/paxml/tasks/vision/params/mnist.py).\n",
    "\n",
    "We use MNIST in Caffe2 format from [DALI_extra](https://github.com/NVIDIA/DALI_extra) as a data source."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "training_data_path = os.path.join(\n",
    "    os.environ[\"DALI_EXTRA_PATH\"], \"db/MNIST/training/\"\n",
    ")\n",
    "validation_data_path = os.path.join(\n",
    "    os.environ[\"DALI_EXTRA_PATH\"], \"db/MNIST/testing/\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first step is to create the iterator definition function that will later be used to create instances of DALI iterators. It defines all steps of the preprocessing. In this simple example we have `fn.readers.caffe2` for reading data in Caffe2 format, `fn.decoders.image` for image decoding, `fn.crop_mirror_normalize` used to normalize the images and `fn.reshape` to adjust the shape of the output tensors.\n",
    "\n",
    "This example focuses on how to use DALI pipeline with Paxml. For more information on writing DALI iterators look into [DALI and JAX getting started](jax-getting_started.ipynb) and [pipeline documentation](../../../pipeline.rst). To learn more about Paxml and how to write neural networks with it, look into [Paxml Github page](https://github.com/google/paxml)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nvidia.dali.fn as fn\n",
    "import nvidia.dali.types as types\n",
    "from nvidia.dali.plugin.jax import data_iterator\n",
    "\n",
    "\n",
    "@data_iterator(\n",
    "    output_map=[\"inputs\", \"labels\"],\n",
    "    reader_name=\"mnist_caffe2_reader\",\n",
    "    auto_reset=True,\n",
    ")\n",
    "def mnist_iterator(data_path, random_shuffle):\n",
    "    jpegs, labels = fn.readers.caffe2(\n",
    "        path=data_path,\n",
    "        random_shuffle=random_shuffle,\n",
    "        name=\"mnist_caffe2_reader\",\n",
    "    )\n",
    "    images = fn.decoders.image(jpegs, device=\"mixed\", output_type=types.GRAY)\n",
    "    images = fn.crop_mirror_normalize(\n",
    "        images, dtype=types.FLOAT, std=[255.0], output_layout=\"HWC\"\n",
    "    )\n",
    "\n",
    "    labels = labels.gpu()\n",
    "    labels = fn.reshape(labels, shape=[])\n",
    "\n",
    "    return images, labels"
   ]
  },
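  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the normalization step concrete: `fn.crop_mirror_normalize` computes `(input - mean) / std` for each element, and `mean` is left at its default of 0 in the iterator above. The sketch below mirrors that arithmetic in plain Python, without DALI, to show that `std=[255.0]` maps 8-bit pixel values into the range [0, 1]. The helper name `normalize_pixels` is hypothetical and only illustrates the formula; it is not how DALI implements the operator.\n",
    "\n",
    "```python\n",
    "def normalize_pixels(pixels, mean=0.0, std=255.0):\n",
    "    \"\"\"Apply (value - mean) / std to each pixel value (illustration only).\"\"\"\n",
    "    return [(float(p) - mean) / std for p in pixels]\n",
    "\n",
    "\n",
    "# 8-bit grayscale values: black, a dark gray, and white.\n",
    "normalized = normalize_pixels([0, 51, 255])\n",
    "print(normalized)\n",
    "```"
   ]
  },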
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example uses Pax data input defined in Praxis. We will create a simple wrapper that uses DALI iterator for JAX as a data source."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from praxis import base_input\n",
    "from nvidia.dali.plugin import jax as dax\n",
    "\n",
    "\n",
    "class MnistDaliInput(base_input.BaseInput):\n",
    "    def __post_init__(self):\n",
    "        super().__post_init__()\n",
    "\n",
    "        data_path = (\n",
    "            training_data_path if self.is_training else validation_data_path\n",
    "        )\n",
    "\n",
    "        training_pipeline = mnist_iterator(\n",
    "            data_path=data_path,\n",
    "            random_shuffle=self.is_training,\n",
    "            batch_size=self.batch_size,\n",
    "        )\n",
    "        self._iterator = dax.DALIGenericIterator(\n",
    "            training_pipeline,\n",
    "            output_map=[\"inputs\", \"labels\"],\n",
    "            reader_name=\"mnist_caffe2_reader\",\n",
    "            auto_reset=True,\n",
    "        )\n",
    "\n",
    "    def get_next(self):\n",
    "        try:\n",
    "            return next(self._iterator)\n",
    "        except StopIteration:\n",
    "            self._iterator.reset()\n",
    "            return next(self._iterator)\n",
    "\n",
    "    def reset(self) -> None:\n",
    "        super().reset()\n",
    "        self._iterator = self._iterator.reset()"
   ]
  },
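  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The restart-on-exhaustion behavior of `get_next` can be tried out without DALI or Praxis. The sketch below is a minimal stand-in: `FiniteSource` and `CyclingInput` are hypothetical names that imitate a one-epoch iterator with a `reset` method and a wrapper that restarts it when `StopIteration` is raised. It only illustrates the control flow, not the DALI iterator itself.\n",
    "\n",
    "```python\n",
    "class FiniteSource:\n",
    "    \"\"\"Stand-in for a one-epoch iterator that must be reset between epochs.\"\"\"\n",
    "\n",
    "    def __init__(self, items):\n",
    "        self._items = list(items)\n",
    "        self._pos = 0\n",
    "\n",
    "    def __iter__(self):\n",
    "        return self\n",
    "\n",
    "    def __next__(self):\n",
    "        if self._pos >= len(self._items):\n",
    "            raise StopIteration\n",
    "        item = self._items[self._pos]\n",
    "        self._pos += 1\n",
    "        return item\n",
    "\n",
    "    def reset(self):\n",
    "        self._pos = 0\n",
    "\n",
    "\n",
    "class CyclingInput:\n",
    "    \"\"\"Wrapper that restarts the source at the end of each epoch.\"\"\"\n",
    "\n",
    "    def __init__(self, source):\n",
    "        self._source = source\n",
    "\n",
    "    def get_next(self):\n",
    "        try:\n",
    "            return next(self._source)\n",
    "        except StopIteration:\n",
    "            # End of epoch: rewind and serve the first batch of the next one.\n",
    "            self._source.reset()\n",
    "            return next(self._source)\n",
    "\n",
    "\n",
    "data = CyclingInput(FiniteSource([\"batch0\", \"batch1\", \"batch2\"]))\n",
    "batches = [data.get_next() for _ in range(5)]\n",
    "print(batches)\n",
    "```\n",
    "\n",
    "Because `get_next` in this sketch never lets `StopIteration` escape, a step-based training loop can request more batches than a single epoch contains."
   ]
  },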
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MnistDaliInput` can be used in Pax `Experiment` as a source of data. Code sample below shows how these two classes can be connected by defining `datasets` method of `Experiment` class.\n",
    "\n",
    "```python\n",
    "  def datasets(self) -> list[pax_fiddle.Config[base_input.BaseInput]]:\n",
    "    return [\n",
    "        pax_fiddle.Config(\n",
    "            MnistDaliInput, batch_size=self.BATCH_SIZE, is_training=True\n",
    "        )\n",
    "    ]\n",
    "```\n",
    "\n",
    "For the full working example you can look into [docs/examples/frameworks/jax/pax_examples](https://github.com/NVIDIA/DALI/tree/main/docs/examples/frameworks/jax/pax_examples). Code in this folder can be tested by running command below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -m paxml.main --job_log_dir=/tmp/dali_pax_logs --exp pax_examples.dali_pax_example.MnistExperiment 2>/dev/null"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It produces a log compatible with tensorboard in /tmp/dali_pax_logs. We use a helper function that reads training accuracy from the logs and prints it in the terminal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.core.util import event_pb2\n",
    "from tensorflow.python.lib.io import tf_record\n",
    "from tensorflow.python.framework import tensor_util\n",
    "\n",
    "\n",
    "def print_logs(path):\n",
    "    \"Helper function to print logs from logs directory created by paxml example\"\n",
    "\n",
    "    def summary_iterator():\n",
    "        for r in tf_record.tf_record_iterator(path):\n",
    "            yield event_pb2.Event.FromString(r)\n",
    "\n",
    "    for summary in summary_iterator():\n",
    "        for value in summary.summary.value:\n",
    "            if value.tag == \"Metrics/accuracy\":\n",
    "                t = tensor_util.MakeNdarray(value.tensor)\n",
    "                print(f\"Iteration: {summary.step}, accuracy: {t}\")"
   ]
  },
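  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`print_logs` expects the path to a single event file, while a summaries directory may also contain other files. As a minimal sketch, the hypothetical helper `find_event_files` below selects only files whose names start with `events.out.tfevents`, the prefix conventionally used for TensorBoard event files, and returns them in sorted order. It uses only the standard library; whether the Paxml output directory holds any other files is an assumption worth checking on your setup.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "\n",
    "def find_event_files(log_dir, prefix=\"events.out.tfevents\"):\n",
    "    \"\"\"Return sorted paths of files in log_dir whose names start with prefix.\"\"\"\n",
    "    return sorted(\n",
    "        os.path.join(log_dir, name)\n",
    "        for name in os.listdir(log_dir)\n",
    "        if name.startswith(prefix)\n",
    "    )\n",
    "\n",
    "\n",
    "# Demonstrate on a temporary directory with one event file and one other file.\n",
    "with tempfile.TemporaryDirectory() as tmp_dir:\n",
    "    for name in (\"events.out.tfevents.123.host\", \"notes.txt\"):\n",
    "        open(os.path.join(tmp_dir, name), \"w\").close()\n",
    "    found = [os.path.basename(p) for p in find_event_files(tmp_dir)]\n",
    "\n",
    "print(found)\n",
    "```\n",
    "\n",
    "Each path returned by such a helper could then be passed to `print_logs`."
   ]
  },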
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With this helper function we can print the accuracy of the training inside Python code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 100, accuracy: 0.3935546875\n",
      "Iteration: 200, accuracy: 0.5634765625\n",
      "Iteration: 300, accuracy: 0.728515625\n",
      "Iteration: 400, accuracy: 0.8369140625\n",
      "Iteration: 500, accuracy: 0.87109375\n",
      "Iteration: 600, accuracy: 0.87890625\n",
      "Iteration: 700, accuracy: 0.884765625\n",
      "Iteration: 800, accuracy: 0.8994140625\n",
      "Iteration: 900, accuracy: 0.8994140625\n",
      "Iteration: 1000, accuracy: 0.90625\n"
     ]
    }
   ],
   "source": [
    "for file in os.listdir(\"/tmp/dali_pax_logs/summaries/train/\"):\n",
    "    print_logs(os.path.join(\"/tmp/dali_pax_logs/summaries/train/\", file))"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Raw Cell Format",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
